#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dataset loading helpers, batch collators and a resumable sampler for SFT, DPO, KTO and PPO/GRPO training.
"""

import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union

import jsonlines
import torch
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer

# Dataset flavours handled in this module: a torch `Dataset`, a `ConcatDataset`
# of several datasets, or a HuggingFace `datasets` dataset.
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
# A filesystem location given either as a plain string or as an `os.PathLike` object.
PathType = Union[str, os.PathLike]


def load_tokenized_dataset(
    dataset_paths: Union[PathType, List[PathType]], mode: str = "train", **kwargs
) -> Optional[DatasetType]:
    """
    Load one or several pre-tokenized datasets from disk via `datasets.load_from_disk`.

    Args:
        dataset_paths: A single path or a list of paths. A falsy value (`None`, `""`, `[]`)
            short-circuits to `None`.
        mode: Which split is wanted; it has to be one of the keys of the mode map.
        **kwargs: Only `mode_map` is read. It maps `mode` to the split name that is looked up
            in a loaded dataset dictionary and defaults to
            `{"train": "train", "dev": "validation", "test": "test"}`.

    Returns:
        `None` when nothing was loaded, the dataset itself when exactly one was loaded,
        otherwise a `ConcatDataset` over all loaded datasets (in path order).

    Raises:
        AssertionError: If `mode` is not a key of the mode map or a path does not exist.
    """
    if not dataset_paths:
        return None

    mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"})
    assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"

    # Normalise the single-path form so the loop below only deals with lists.
    if isinstance(dataset_paths, (str, os.PathLike)):
        dataset_paths = [dataset_paths]

    loaded = []
    for raw_path in dataset_paths:
        abs_path = os.path.abspath(raw_path)
        assert os.path.exists(abs_path), f"Not existed file path {abs_path}"
        on_disk = load_from_disk(dataset_path=abs_path, keep_in_memory=False)
        if isinstance(on_disk, HFDataset):
            # A bare dataset has no splits, so it is taken as is.
            loaded.append(on_disk)
        elif mode_map[mode] in on_disk:
            # A dataset dictionary contributes only the requested split; it is skipped
            # silently when that split is missing.
            loaded.append(on_disk[mode_map[mode]])

    if not loaded:
        return None
    if len(loaded) == 1:
        return loaded[0]
    return ConcatDataset(datasets=loaded)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """
    Batch collator for supervised fine-tuning samples.

    Every sample is a dictionary from which `input_ids`(List[int]) and `labels`(List[int])
    are read; other fields of a sample are not touched by this collator.
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    ignore_index: int = -100

    def _clip_to_tensor(self, token_ids: List[int]) -> torch.Tensor:
        """Cut `token_ids` down to `max_length` entries if needed and wrap it in a `LongTensor`."""
        if len(token_ids) > self.max_length:
            token_ids = token_ids[: self.max_length]
        return torch.LongTensor(token_ids)

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Truncate, pad and stack a mini-batch.

        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary.

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`: padded with `tokenizer.pad_token_id`;
                `attention_mask`: `torch.BoolTensor`, `True` wherever `input_ids` differs from the pad id;
                `labels`: padded with `ignore_index`.
            With right padding the tensors are widened to `max_length` columns; with left padding
            they are only as wide as the longest (truncated) sequence of the batch.

        Raises:
            AssertionError: If the tokenizer has no valid non-negative integer `pad_token_id`.
            RuntimeError: If `tokenizer.padding_side` is neither `left` nor `right`.
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{self.tokenizer.pad_token_id}`"
        )
        pad_id = self.tokenizer.pad_token_id

        token_seqs = [self._clip_to_tensor(sample["input_ids"]) for sample in instances]
        label_seqs = [self._clip_to_tensor(sample["labels"]) for sample in instances]

        side = self.tokenizer.padding_side
        if side not in ("right", "left"):
            raise RuntimeError(
                f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
                f"but now `{self.tokenizer.padding_side}`"
            )

        batch_pad = torch.nn.utils.rnn.pad_sequence
        if side == "right":
            input_ids = batch_pad(token_seqs, batch_first=True, padding_value=pad_id)
            labels = batch_pad(label_seqs, batch_first=True, padding_value=self.ignore_index)
            # Widen to `max_length`; the amount is derived from `input_ids` and reused for `labels`.
            missing = self.max_length - input_ids.size(1)
            input_ids = F.pad(input_ids, (0, missing), value=pad_id)
            labels = F.pad(labels, (0, missing), value=self.ignore_index)
        else:
            # `pad_sequence` only pads on the right, so every sequence is reversed, padded,
            # and the padded batch is reversed back along the sequence axis.
            mirrored_ids = batch_pad(
                [seq.flip(dims=(0,)) for seq in token_seqs], batch_first=True, padding_value=pad_id
            )
            input_ids = torch.flip(mirrored_ids, dims=(1,))
            mirrored_labels = batch_pad(
                [seq.flip(dims=(0,)) for seq in label_seqs], batch_first=True, padding_value=self.ignore_index
            )
            labels = torch.flip(mirrored_labels, dims=(1,))

        # Any position holding the pad id counts as padding.
        attention_mask = input_ids != pad_id

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


@dataclass
class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset):
    """
    Batch collator for prompt-only samples; reuses the supervised collator for truncation and padding.
    """

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Collate prompts and left-pad them up to `max_length`.

        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary
                with an `input_ids` field and an optional `gt_answer` field.

        Returns:
            (`Dict[str, Any]`): Contains the following entries:
                `input_ids`: `torch.Tensor` of shape (bsz, max_len);
                `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
                `gt_answer`: list with each sample's `gt_answer` (`None` where a sample has none).
        """
        answers = [sample.get("gt_answer", None) for sample in instances]
        # The parent collator requires a `labels` field; the prompt ids stand in for it and the
        # collated labels are dropped afterwards.
        prompt_only = [{"input_ids": sample["input_ids"], "labels": sample["input_ids"]} for sample in instances]
        collated = super().__call__(instances=prompt_only)

        ids = collated["input_ids"]
        mask = collated["attention_mask"]
        # Fill whatever width is still missing on the LEFT side (a no-op if already `max_length` wide).
        ids = F.pad(ids, (self.max_length - ids.size(1), 0), value=self.tokenizer.pad_token_id)
        mask = F.pad(mask, (self.max_length - mask.size(1), 0), value=False)
        return {"input_ids": ids, "attention_mask": mask, "gt_answer": answers}


@dataclass
class DataCollatorForPreferenceDataset(object):
    """
    Batch collator for preference (chosen / rejected) samples.

    Each sample is a dictionary from which `chosen_input_ids`, `chosen_loss_mask`,
    `rejected_input_ids` and `rejected_loss_mask` are read.
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Truncate and pad the chosen and rejected sequences of a mini-batch.

        Truncation and padding are delegated to `chuncate_sequence` and `pad_to_max_len` from
        `coati.dataset.utils`; both are called with `max_length` (presumably yielding
        `max_length`-wide batches -- confirm against those helpers).

        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary.

        Returns:
            (`Dict[str, torch.Tensor]`): The padded `chosen_input_ids`, `chosen_attention_mask`,
                `chosen_loss_mask`, `reject_input_ids`, `reject_attention_mask` and `reject_loss_mask`.

        Raises:
            AssertionError: If the tokenizer has no valid non-negative integer `pad_token_id`.
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{self.tokenizer.pad_token_id}`"
        )

        limit = self.max_length
        chosen_input_ids = chuncate_sequence([sample["chosen_input_ids"] for sample in instances], limit, torch.int64)
        chosen_loss_mask = chuncate_sequence([sample["chosen_loss_mask"] for sample in instances], limit, torch.bool)
        reject_input_ids = chuncate_sequence(
            [sample["rejected_input_ids"] for sample in instances], limit, torch.int64
        )
        reject_loss_mask = chuncate_sequence(
            [sample["rejected_loss_mask"] for sample in instances], limit, torch.bool
        )

        side = self.tokenizer.padding_side
        pad_id = self.tokenizer.pad_token_id
        # Before padding every token is real, hence all-True masks of matching length.
        chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids]
        reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids]

        # Ids are filled with the pad id, every mask with False.
        chosen_input_ids = pad_to_max_len(chosen_input_ids, limit, pad_id, padding_side=side)
        chosen_attention_mask = pad_to_max_len(chosen_attention_mask, limit, False, padding_side=side)
        chosen_loss_mask = pad_to_max_len(chosen_loss_mask, limit, False, padding_side=side)
        reject_input_ids = pad_to_max_len(reject_input_ids, limit, pad_id, padding_side=side)
        reject_attention_mask = pad_to_max_len(reject_attention_mask, limit, False, padding_side=side)
        reject_loss_mask = pad_to_max_len(reject_loss_mask, limit, False, padding_side=side)

        return {
            "chosen_input_ids": chosen_input_ids,
            "chosen_attention_mask": chosen_attention_mask,
            "chosen_loss_mask": chosen_loss_mask,
            "reject_input_ids": reject_input_ids,
            "reject_attention_mask": reject_attention_mask,
            "reject_loss_mask": reject_loss_mask,
        }


@dataclass
class DataCollatorForKTODataset(object):
    """
    Batch collator for KTO samples.

    Each input sample is a dictionary with the fields
    `prompt`(List[int]), `completion`(List[int]) and `label`(bool).
    The collated batch holds `input_ids`, `attention_mask`, `loss_mask` and `label` for the
    original prompt/completion pairs plus `kl_input_ids`, `kl_attention_mask` and `kl_loss_mask`
    for prompts paired with the completions of the batch taken in reverse order.
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    ignore_index: int = -100

    def _join_and_pad(self, prompts: List[torch.Tensor], completions: List[torch.Tensor]):
        """Concatenate each prompt with its completion and right-pad ids and loss mask to `max_length`."""
        pad_id = self.tokenizer.pad_token_id
        joined = [torch.cat([head, tail], dim=-1) for head, tail in zip(prompts, completions)]
        # Loss mask: 0 over the prompt tokens, 1 over the completion tokens.
        masks = [
            torch.cat([torch.zeros_like(head), torch.ones_like(tail)], dim=-1)
            for head, tail in zip(prompts, completions)
        ]
        ids = torch.nn.utils.rnn.pad_sequence(joined, batch_first=True, padding_value=pad_id)
        mask = torch.nn.utils.rnn.pad_sequence(masks, batch_first=True, padding_value=0)
        ids = F.pad(ids, (0, self.max_length - ids.size(1)), value=pad_id)
        mask = F.pad(mask, (0, self.max_length - mask.size(1)), value=0)
        return ids, mask

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Build the right-padded KTO batch.

        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary contains the following fields:
                `prompt`(List[int]), `completion`(List[int]) and `label`(bool, if the sample is desirable or not).

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`, `kl_input_ids`: padded with `tokenizer.pad_token_id`;
                `attention_mask`, `kl_attention_mask`: `torch.BoolTensor`, `True` where the ids differ from the pad id;
                `loss_mask`, `kl_loss_mask`: 1 on completion tokens, 0 on prompt tokens and padding;
                `label`: `torch.BoolTensor` with one entry per sample.

        Raises:
            AssertionError: If the tokenizer has no valid non-negative integer `pad_token_id`.
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{self.tokenizer.pad_token_id}`"
        )
        pad_id = self.tokenizer.pad_token_id

        prompts = [torch.LongTensor(sample["prompt"]) for sample in instances]
        completions = [torch.LongTensor(sample["completion"]) for sample in instances]
        desirable = [torch.tensor(sample["label"], dtype=torch.bool) for sample in instances]

        # Original pairs: prompt i followed by completion i.
        input_ids, loss_mask = self._join_and_pad(prompts, completions)
        attention_mask = input_ids != pad_id

        # KL pairs: prompt i followed by the completion at the mirrored batch position (y').
        kl_input_ids, kl_loss_mask = self._join_and_pad(prompts, completions[::-1])
        kl_attention_mask = kl_input_ids != pad_id

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "label": torch.stack(desirable),
            "kl_input_ids": kl_input_ids,
            "kl_attention_mask": kl_attention_mask,
            "kl_loss_mask": kl_loss_mask,
        }


class StatefulDistributedSampler(DistributedSampler):
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        iterator = super().__iter__()
        indices = list(iterator)
        indices = indices[self.start_index :]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples - self.start_index

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


def apply_chat_template_and_mask(
    tokenizer: PreTrainedTokenizer,
    chat: List[Dict[str, str]],
    max_length: Optional[int] = None,
    system_prompt: Optional[str] = None,
    padding: bool = True,
    truncation: bool = True,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """
    Tokenize a conversation with the tokenizer's chat template and build training labels.

    Every message is rendered together with the system prompt through
    `tokenizer.apply_chat_template(..., add_generation_prompt=True)` and the resulting token
    lists are concatenated. Labels equal the input ids on tokens of `assistant` messages and
    `ignore_idx` everywhere else.

    Args:
        tokenizer: Tokenizer providing `apply_chat_template`, `bos_token_id` and `pad_token_id`.
        chat: Either a list of `{"role": ..., "content": ...}` messages, or (RL format) a
            dictionary whose `messages` entry is a single message and which may additionally
            carry `gt_answer` and/or `test_cases`.
        max_length: Target length; `None` disables both padding and truncation.
        system_prompt: System message content; a built-in math-reasoning prompt is used when `None`.
        padding: Left-pad with `tokenizer.pad_token_id` up to `max_length`.
        truncation: Keep only the first `max_length` tokens of longer sequences.
        ignore_idx: Label value for positions that do not belong to an assistant message.

    Returns:
        Dictionary with the 1-D long tensors `input_ids`, `attention_mask` and `labels`. For the
        RL format, `gt_answer` is added when present; otherwise `test_cases` is added when present.
    """
    if system_prompt is None:
        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"

    system_element = {
        "role": "system",
        "content": system_prompt,
    }

    # Both extras must exist for every input: a plain message list carries neither, and leaving
    # them unbound made the lookups at the end of this function raise `UnboundLocalError`.
    gt_answer = None
    test_cases = None

    # Format for RL.
    if "messages" in chat:
        gt_answer = chat.get("gt_answer", None)
        test_cases = chat.get("test_cases", None)
        chat = [chat["messages"]]

    tokens = []
    assistant_mask = []
    for i, msg in enumerate(chat):
        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
        # Remove the unexpected bos token of every message after the first one
        # (an empty rendering has nothing to strip).
        if i > 0 and msg_tokens and msg_tokens[0] == tokenizer.bos_token_id:
            msg_tokens = msg_tokens[1:]
        tokens.extend(msg_tokens)
        assistant_mask.extend([msg["role"] == "assistant"] * len(msg_tokens))

    attention_mask = [1] * len(tokens)
    if max_length is not None:
        if padding and len(tokens) < max_length:
            to_pad = max_length - len(tokens)
            # Left padding for generation.
            tokens = [tokenizer.pad_token_id] * to_pad + tokens
            assistant_mask = [False] * to_pad + assistant_mask
            attention_mask = [0] * to_pad + attention_mask
        if truncation and len(tokens) > max_length:
            tokens = tokens[:max_length]
            assistant_mask = assistant_mask[:max_length]
            attention_mask = attention_mask[:max_length]

    input_ids = torch.tensor(tokens, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    labels = input_ids.clone()
    labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

    result = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
    # `gt_answer` takes precedence over `test_cases` when a sample carries both.
    if gt_answer is not None:
        result["gt_answer"] = gt_answer
    elif test_cases is not None:
        result["test_cases"] = test_cases
    return result


class RawConversationDataset(Dataset):
    """
    Dataset over the records of a jsonlines file that tokenizes each record on first access.

    Records are kept exactly as read from the file; their tokenized form is produced by
    `apply_chat_template_and_mask` and cached per index.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
        """
        Args:
            tokenizer: Tokenizer handed to `apply_chat_template_and_mask`.
            input_file: Path of the jsonlines file; every line becomes one sample.
            max_length: Maximum length handed to `apply_chat_template_and_mask`.
            system_prompt: System prompt handed to `apply_chat_template_and_mask`.
        """
        self.tokenizer = tokenizer
        with jsonlines.open(input_file) as reader:
            self.raw_texts = list(reader)
        # One cache slot per record; `None` marks a record that has not been tokenized yet.
        self.tokenized_texts = [None] * len(self.raw_texts)
        self.max_length = max_length
        self.system_prompt = system_prompt

    def __len__(self) -> int:
        """Number of records read from the input file."""
        return len(self.raw_texts)

    def __getitem__(self, index: int):
        """Return the tokenized record at `index`, tokenizing and caching it on first access."""
        cached = self.tokenized_texts[index]
        if cached is None:
            cached = dict(
                apply_chat_template_and_mask(
                    self.tokenizer, self.raw_texts[index], self.max_length, self.system_prompt
                )
            )
            self.tokenized_texts[index] = cached
        return cached


def collate_fn_grpo(batch):
    """
    Stack already equal-length samples into one batch.

    Args:
        batch: Sequence of dictionaries, each holding the tensors `input_ids`, `attention_mask`
            and `labels` and optionally `test_cases` and/or `gt_answer`.

    Returns:
        Dictionary with the stacked `input_ids`, `attention_mask` and `labels`. When the FIRST
        sample carries `test_cases` / `gt_answer`, the per-sample values are collected into lists
        under the same key.
    """
    tensor_fields = ("input_ids", "attention_mask", "labels")
    # No padding happens here: `torch.stack` requires every sample to have the same length already,
    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id).
    columns = {field: [sample[field] for sample in batch] for field in tensor_fields}
    collated = {field: torch.stack(column) for field, column in columns.items()}

    first_sample = batch[0]
    for extra in ("test_cases", "gt_answer"):
        if extra in first_sample:
            collated[extra] = [sample[extra] for sample in batch]
    return collated
